/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSED_MHA_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSED_MHA_THUNK_H_

#include <memory>
#include <optional>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fused_mha_runner.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/tsl/platform/status.h"

namespace xla {
namespace gpu {

// This class stores everything that StreamExecutor needs to launch a DNN
// fused multi-headed attention (fMHA) forward pass. It is generated by
// IrEmitter.
//
// Runners are cached per `stream_executor::Stream` in `runner_cache_`, which
// is guarded by `mu_`. NOTE(review): this class was previously documented as
// "thread-compatible"; the internal mutex suggests ExecuteOnStream is intended
// to be callable concurrently for different streams — confirm against the
// implementation before relying on that.
class FusedMHAThunk : public Thunk {
 public:
  // Constructs a thunk for launching a DNN fMHA forward pass.
  //
  // `config` is stored by value and used for every runner this thunk creates.
  // The `*_bmm1_*` / `*_bmm2_*` slices hold the operands of the two batched
  // matmuls, `output_slice` receives the result and `scratch_slice` is
  // workspace memory (roles inferred from names — confirm in the .cc).
  // `mask_slice`, `bias_slice` and `activation_slice` may be null slices when
  // the corresponding input/output is absent.
  FusedMHAThunk(ThunkInfo thunk_info, GpufMHAConfig config,
                BufferAllocation::Slice lhs_bmm1_slice,
                BufferAllocation::Slice rhs_bmm1_slice,
                BufferAllocation::Slice rhs_bmm2_slice,
                BufferAllocation::Slice output_slice,
                BufferAllocation::Slice scratch_slice,
                BufferAllocation::Slice mask_slice /* may be null */,
                BufferAllocation::Slice bias_slice /* may be null */,
                BufferAllocation::Slice activation_slice /* may be null */);

  // Non-copyable: the thunk owns a mutex and a cache of unique_ptr runners.
  FusedMHAThunk(const FusedMHAThunk&) = delete;
  FusedMHAThunk& operator=(const FusedMHAThunk&) = delete;

  // Launches the fMHA forward pass on `params`' stream.
  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  // Buffer slices captured at construction; see the constructor for roles.
  BufferAllocation::Slice lhs_bmm1_buffer_;
  BufferAllocation::Slice rhs_bmm1_buffer_;
  BufferAllocation::Slice rhs_bmm2_buffer_;
  BufferAllocation::Slice output_buffer_;
  BufferAllocation::Slice scratch_buffer_;
  BufferAllocation::Slice mask_buffer_;
  BufferAllocation::Slice bias_buffer_;
  BufferAllocation::Slice activation_buffer_;

  // Returns the runner cached for `stream`, creating it on first use.
  FusedMultiHeadedAttentionRunner& GetOrCreateRunner(
      const stream_executor::Stream* stream);

  // FusedMHA config shared by every runner created by this thunk.
  const GpufMHAConfig config_;
  // Protects `runner_cache_`.
  absl::Mutex mu_;
  // One runner per stream, keyed by the (non-owned) stream pointer.
  absl::flat_hash_map<const stream_executor::Stream*,
                      std::unique_ptr<FusedMultiHeadedAttentionRunner>>
      runner_cache_ ABSL_GUARDED_BY(mu_);
};

// This class stores everything that StreamExecutor needs to launch a DNN
// fused multi-headed attention (fMHA) backward pass.
//
// Runners are cached per `stream_executor::Stream` in `runner_cache_`, which
// is guarded by `mu_`.
class FusedMHABackwardThunk : public Thunk {
 public:
  // Constructs a thunk for launching a DNN fMHA backward pass.
  //
  // `config` is stored by value and used for every runner this thunk creates.
  // The `bmm*_grad_gemm*` slices are the inputs of the gradient GEMMs,
  // `d_output_slice` the incoming output gradient, `scratch_slice` workspace
  // memory, and the `d_*` slices receive the computed gradients (roles
  // inferred from names — confirm in the .cc). NOTE(review): `mask_slice` and
  // `d_bias_slice` are presumably optional (null slices when absent), as in
  // the forward thunk — verify.
  FusedMHABackwardThunk(ThunkInfo thunk_info, GpufMHABackwardConfig config,
                        BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice,
                        BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice,
                        BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice,
                        BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice,
                        BufferAllocation::Slice d_output_slice,
                        BufferAllocation::Slice scratch_slice,
                        BufferAllocation::Slice d_bmm1_lhs_slice,
                        BufferAllocation::Slice d_bmm1_rhs_slice,
                        BufferAllocation::Slice d_bmm2_rhs_slice,
                        BufferAllocation::Slice d_s_slice,
                        BufferAllocation::Slice mask_slice,
                        BufferAllocation::Slice d_bias_slice);

  // Non-copyable: the thunk owns a mutex and a cache of unique_ptr runners.
  FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete;
  FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete;

  // Launches the fMHA backward pass on `params`' stream.
  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  // Buffer slices captured at construction; see the constructor for roles.
  BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_;
  BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_;
  BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_;
  BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_;
  BufferAllocation::Slice d_output_buffer_;
  BufferAllocation::Slice scratch_buffer_;
  BufferAllocation::Slice d_bmm1_lhs_buffer_;
  BufferAllocation::Slice d_bmm1_rhs_buffer_;
  BufferAllocation::Slice d_bmm2_rhs_buffer_;
  BufferAllocation::Slice d_s_buffer_;
  BufferAllocation::Slice mask_buffer_;
  BufferAllocation::Slice d_bias_buffer_;

  // Returns the runner cached for `stream`, creating it on first use.
  FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner(
      const stream_executor::Stream* stream);

  // FusedMHA backward config shared by every runner created by this thunk.
  const GpufMHABackwardConfig config_;
  // Protects `runner_cache_`.
  absl::Mutex mu_;
  // One runner per stream, keyed by the (non-owned) stream pointer.
  absl::flat_hash_map<const stream_executor::Stream*,
                      std::unique_ptr<FusedMultiHeadedAttentionBackwardRunner>>
      runner_cache_ ABSL_GUARDED_BY(mu_);
};
}  // namespace gpu
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSED_MHA_THUNK_H_
